# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FFT interface for fast Fourier transforms using the FFTW backend (through pyfftw).
:class:`~hysop.numerics.FftwFFT`
:class:`~hysop.numerics.FftwFFTPlan`
"""
import warnings
import pyfftw
import numpy as np
from hysop import (
__FFTW_NUM_THREADS__,
__FFTW_PLANNER_EFFORT__,
__FFTW_PLANNER_TIMELIMIT__,
__VERBOSE__,
)
from hysop.tools.io_utils import IO
from hysop.tools.htypes import first_not_None
from hysop.tools.misc import prod
from hysop.tools.string_utils import framed_str
from hysop.tools.cache import load_data_from_cache, update_cache
from hysop.numerics.fft.fft import HysopFFTWarning, bytes2str
from hysop.numerics.fft.host_fft import HostFFTPlanI, HostFFTI, HostArray
class FftwFFTPlan(HostFFTPlanI):
    """
    Build and wrap a FFTW plan (pyfftw.FFTW).

    Planning first tries to rebuild the plan from cached FFTW wisdom only and
    falls back to regular planning when that is not possible. The resulting
    pyfftw wisdom is then written back to the cache file.
    A HysopFFTWarning is emitted when the resulting plan is not SIMD aligned.
    """

    # Set to False to disable reading/writing FFTW wisdom from/to the cache.
    __FFTW_USE_CACHE__ = True

    @classmethod
    def cache_file(cls):
        """Return the path of the FFTW wisdom cache file (under IO.cache_path())."""
        cache_dir = IO.cache_path() + "/numerics"
        return cache_dir + "/fftw_wisdom.pklz"

    @classmethod
    def load_wisdom(cls, h):
        """
        Import the FFTW wisdom cached under key `h` into pyfftw.

        Returns True if cached wisdom was found and imported, False otherwise
        (including when the cache is disabled).
        """
        if not cls.__FFTW_USE_CACHE__:
            return False
        wisdom = load_data_from_cache(cls.cache_file(), h)
        if wisdom is None:
            return False
        pyfftw.import_wisdom(wisdom)
        return True

    @classmethod
    def save_wisdom(cls, h, plan):
        """
        Store the current pyfftw wisdom in the cache under key `h`.

        `plan` is not used; it is kept for interface compatibility.
        """
        if cls.__FFTW_USE_CACHE__:
            update_cache(cls.cache_file(), h, pyfftw.export_wisdom())

    def __init__(self, a, out, scaling=None, **plan_kwds):
        """
        Plan a FFTW transform from `a` to `out`.

        Parameters
        ----------
        a, out: HostArray or array
            Input and output arrays. HostArray instances are unwrapped through
            their `handle` attribute before being given to pyfftw.
        scaling: optional
            If not None, the output array is multiplied in place by this
            factor after each execution.
        plan_kwds: dict
            Keyword arguments forwarded to pyfftw.FFTW. 'flags' is required;
            'axes', 'direction', 'threads' and 'planning_timelimit' are also
            read when printing the verbose summary. The optional 'verbose'
            entry is consumed here (defaults to __VERBOSE__).
        """
        verbose = plan_kwds.pop("verbose", __VERBOSE__)
        super().__init__(verbose=verbose)

        plan_kwds["input_array"] = a.handle if isinstance(a, HostArray) else a
        plan_kwds["output_array"] = out.handle if isinstance(out, HostArray) else out

        # Only build the (purely informative) summary when it is printed.
        if self.verbose:
            self._print_planning_summary(plan_kwds)

        # All plans currently share a single wisdom cache entry (key None).
        h = None
        plan = None
        if self.load_wisdom(h):
            # Try to build the plan from the imported wisdom only. pyfftw
            # raises RuntimeError when the wisdom does not cover this plan
            # (e.g. the wisdom only holds measure-level knowledge).
            plan_kwds["flags"] += ("FFTW_WISDOM_ONLY",)
            try:
                plan = pyfftw.FFTW(**plan_kwds)
            except RuntimeError:
                plan = None
        if plan is None:
            # Regular planning: FFTW_WISDOM_ONLY must be removed, otherwise
            # planning would fail again when no usable wisdom exists.
            plan_kwds["flags"] = tuple(
                flag for flag in plan_kwds["flags"] if flag != "FFTW_WISDOM_ONLY"
            )
            plan = pyfftw.FFTW(**plan_kwds)
        self.save_wisdom(h, plan)

        if not plan.simd_aligned:
            msg = "Resulting plan is not SIMD aligned ({} bytes boundary)."
            msg = msg.format(pyfftw.simd_alignment)
            warnings.warn(msg, HysopFFTWarning)

        self.plan = plan
        self.scaling = scaling
        self.out = out
        self.a = a

    def _print_planning_summary(self, plan_kwds):
        """Print a framed summary of the arguments about to be given to pyfftw.FFTW."""

        def fmt_array(name):
            arr = plan_kwds[name]
            return "shape={:<16} strides={:<16} dtype={:<16}".format(
                str(arr.shape) + ",", str(arr.strides) + ",", str(arr.dtype)
            )

        title = f" Planning {self.__class__.__name__} "
        msg = "\n".join(
            (
                f" in_array: {fmt_array('input_array')}",
                f"out_array: {fmt_array('output_array')}",
                f"axes: {plan_kwds['axes']}",
                f"direction: {plan_kwds['direction']}",
                f"threads: {plan_kwds['threads']}",
                "flags: {}".format(" | ".join(plan_kwds["flags"])),
                f"planning timelimit: {plan_kwds['planning_timelimit']}",
            )
        )
        print()
        print(framed_str(title, msg, c="*"))

    @property
    def input_array(self):
        """Input array as given to the constructor (possibly a HostArray)."""
        return self.a

    @property
    def output_array(self):
        """Output array as given to the constructor (possibly a HostArray)."""
        return self.out

    def execute(self):
        """
        Execute the plan on the current input and output arrays, then
        multiply the output in place by `scaling` when one was given.
        """
        # NOTE(review): this goes through pyfftw.FFTW.__call__ (not
        # FFTW.execute); per pyfftw docs __call__ normalises backward
        # transforms by default (normalise_idft) -- confirm this is intended.
        self.plan()
        if self.scaling is not None:
            self.output_array[...] *= self.scaling

    def __call__(self):
        """
        Execute the plan on the current input and output arrays (see
        execute) and return the output array for convenience.
        """
        self.execute()
        return self.output_array
class FftwFFT(HostFFTI):
    """
    Interface to compute local to process FFT-like transforms using the FFTW backend.

    Fftw fft backend has many advantages:
        - single, double and long double precision supported
        - no intermediate temporary buffers created at each call.
        - planning capability with caching
        - multithreading capability

    Planning destroys initial arrays content.
    """

    # FFTW real-to-real kinds, indexed by (DCT/DST type - 1).
    _DCT_KINDS = ("FFTW_REDFT00", "FFTW_REDFT10", "FFTW_REDFT01", "FFTW_REDFT11")
    _DST_KINDS = ("FFTW_RODFT00", "FFTW_RODFT10", "FFTW_RODFT01", "FFTW_RODFT11")

    def __init__(
        self,
        threads=None,
        planner_effort=None,
        planning_timelimit=None,
        destroy_input=False,
        warn_on_misalignment=True,
        warn_on_allocation=True,
        error_on_allocation=False,
        backend=None,
        allocator=None,
        **kwds,
    ):
        """
        Parameters
        ----------
        threads: int, optional
            Number of planning/execution threads, defaults to
            __FFTW_NUM_THREADS__.
        planner_effort: str, optional
            FFTW planner flag added to every plan, defaults to
            __FFTW_PLANNER_EFFORT__.
        planning_timelimit: optional
            pyfftw planning time limit, defaults to
            __FFTW_PLANNER_TIMELIMIT__.
        destroy_input: bool
            Whether plans get the FFTW_DESTROY_INPUT flag by default.
        warn_on_misalignment: bool
            Warn when input/output arrays are not SIMD aligned.
        warn_on_allocation, error_on_allocation, backend, allocator, kwds:
            Forwarded to HostFFTI.__init__.
        """
        threads = first_not_None(threads, __FFTW_NUM_THREADS__)
        planner_effort = first_not_None(planner_effort, __FFTW_PLANNER_EFFORT__)
        planning_timelimit = first_not_None(
            planning_timelimit, __FFTW_PLANNER_TIMELIMIT__
        )
        super().__init__(
            backend=backend,
            allocator=allocator,
            warn_on_allocation=warn_on_allocation,
            error_on_allocation=error_on_allocation,
            **kwds,
        )
        self.supported_ftypes = (np.float32, np.float64, np.longdouble)
        self.supported_ctypes = (np.complex64, np.complex128, np.clongdouble)
        self.supported_cosine_transforms = (1, 2, 3, 4)
        self.supported_sine_transforms = (1, 2, 3, 4)
        self.threads = threads
        self.planner_effort = planner_effort
        self.planning_timelimit = planning_timelimit
        self.destroy_input = destroy_input
        self.warn_on_misalignment = warn_on_misalignment

    @classmethod
    def check_alignment(cls, a, out):
        """
        Check SIMD alignment of input and output arrays.

        Emits one HysopFFTWarning for each array (input and/or output) that
        is not aligned on pyfftw.simd_alignment bytes; None arrays are skipped.
        """
        msg0 = "{} array is not aligned on SIMD aligment ({} bytes)."
        # Both arrays are checked independently: a misaligned input must not
        # hide a misaligned output.
        for name, arr in (("Input", a), ("Output", out)):
            if (arr is not None) and not pyfftw.is_byte_aligned(arr):
                msg = msg0.format(name, pyfftw.simd_alignment)
                warnings.warn(msg, HysopFFTWarning)

    def bake_kwds(self, **kwds):
        """
        Translate transform keyword arguments into FftwFFTPlan keyword arguments.

        Required keys: 'a', 'out', 'direction' and one of 'axes' / 'axis'
        ('axes' takes precedence; a lone 'axis' becomes a one-element tuple).
        Optional keys: 'threads' and 'planning_timelimit' (default to the
        instance settings), 'verbose' (defaults to __VERBOSE__),
        'planner_effort' (defaults to the instance setting),
        'destroy_input' (defaults to the instance setting) and 'wisdom_only'
        (defaults to False). The last three are turned into FFTW flags.

        Raises
        ------
        KeyError
            If a required key is missing.
        RuntimeError
            If unknown keyword arguments remain.
        """
        plan_kwds = {}
        plan_kwds["a"] = kwds.pop("a")
        plan_kwds["out"] = kwds.pop("out")
        plan_kwds["direction"] = kwds.pop("direction")
        if "axes" in kwds:
            plan_kwds["axes"] = kwds.pop("axes")
            kwds.pop("axis", None)  # superseded by 'axes'
        else:
            plan_kwds["axes"] = (kwds.pop("axis"),)
        plan_kwds["threads"] = kwds.pop("threads", self.threads)
        plan_kwds["verbose"] = kwds.pop("verbose", __VERBOSE__)
        plan_kwds["planning_timelimit"] = kwds.pop(
            "planning_timelimit", self.planning_timelimit
        )

        flags = (kwds.pop("planner_effort", self.planner_effort),)
        if kwds.pop("destroy_input", self.destroy_input):
            flags += ("FFTW_DESTROY_INPUT",)
        if kwds.pop("wisdom_only", False):
            flags += ("FFTW_WISDOM_ONLY",)
        plan_kwds["flags"] = flags

        if kwds:
            msg = "Unknown keyword arguments: {}"
            msg = msg.format(", ".join(f"'{kwd}'" for kwd in kwds.keys()))
            raise RuntimeError(msg)
        return plan_kwds

    @staticmethod
    def _r2r_kind(kinds, ttype):
        """Return the FFTW r2r kind in `kinds` for a 1-based DCT/DST type (1 to 4)."""
        idx = int(ttype) - 1
        # Guard explicitly: a negative index would silently pick the wrong kind.
        if not 0 <= idx < len(kinds):
            msg = f"Unsupported transform type {ttype}, expected 1, 2, 3 or 4."
            raise ValueError(msg)
        return kinds[idx]

    def _build_plan(self, a, out, shape, dtype, axis, direction, kwds, scaling=None):
        """
        Shared planning path of all transforms: allocate `out` if needed,
        optionally check SIMD alignment, then build a FftwFFTPlan.

        `kwds` is the dict of extra keyword arguments received by the public
        transform method; it is forwarded to bake_kwds.
        """
        out = self.allocate_output(out, shape, dtype)
        if self.warn_on_misalignment:
            self.check_alignment(a, out)
        plan_kwds = self.bake_kwds(
            a=a, out=out, axis=axis, direction=direction, **kwds
        )
        return FftwFFTPlan(scaling=scaling, **plan_kwds)

    def fft(self, a, out=None, axis=-1, **kwds):
        """
        Plan a forward FFT (FFTW_FORWARD) along `axis`.
        Planning destroys initial arrays content.
        """
        (shape, dtype) = super().fft(a=a, out=out, axis=axis, **kwds)
        return self._build_plan(a, out, shape, dtype, axis, "FFTW_FORWARD", kwds)

    def ifft(self, a, out=None, axis=-1, **kwds):
        """
        Plan a backward FFT (FFTW_BACKWARD) along `axis`.
        Planning destroys initial arrays content.
        """
        (shape, dtype, _) = super().ifft(a=a, out=out, axis=axis, **kwds)
        return self._build_plan(a, out, shape, dtype, axis, "FFTW_BACKWARD", kwds)

    def rfft(self, a, out=None, axis=-1, **kwds):
        """
        Plan a forward real FFT (FFTW_FORWARD) along `axis`.
        Planning destroys initial arrays content.
        """
        (shape, dtype) = super().rfft(a=a, out=out, axis=axis, **kwds)
        return self._build_plan(a, out, shape, dtype, axis, "FFTW_FORWARD", kwds)

    def irfft(self, a, out=None, n=None, axis=-1, **kwds):
        """
        Plan a backward real FFT (FFTW_BACKWARD) along `axis`.
        Planning destroys initial arrays content.
        """
        (shape, dtype, _) = super().irfft(a=a, out=out, axis=axis, n=n, **kwds)
        return self._build_plan(a, out, shape, dtype, axis, "FFTW_BACKWARD", kwds)

    def dct(self, a, out=None, type=2, axis=-1, **kwds):
        """
        Plan a DCT of the given type (1 to 4) along `axis`.
        Planning destroys initial arrays content.
        """
        (shape, dtype) = super().dct(a=a, out=out, type=type, axis=axis, **kwds)
        direction = self._r2r_kind(self._DCT_KINDS, type)
        return self._build_plan(a, out, shape, dtype, axis, direction, kwds)

    def idct(self, a, out=None, type=2, axis=-1, scaling=None, **kwds):
        """
        Plan an inverse DCT along `axis`, using the transform type returned
        by HostFFTI.idct. The output is scaled by `scaling` (defaults to 1/s,
        with s also returned by HostFFTI.idct).
        Planning destroys initial arrays content.
        """
        (shape, dtype, itype, s) = super().idct(
            a=a, out=out, type=type, axis=axis, **kwds
        )
        scaling = first_not_None(scaling, 1.0 / s)
        direction = self._r2r_kind(self._DCT_KINDS, itype)
        return self._build_plan(
            a, out, shape, dtype, axis, direction, kwds, scaling=scaling
        )

    def dst(self, a, out=None, type=2, axis=-1, **kwds):
        """
        Plan a DST of the given type (1 to 4) along `axis`.
        Planning destroys initial arrays content.
        """
        (shape, dtype) = super().dst(a=a, out=out, type=type, axis=axis, **kwds)
        direction = self._r2r_kind(self._DST_KINDS, type)
        return self._build_plan(a, out, shape, dtype, axis, direction, kwds)

    def idst(self, a, out=None, type=2, axis=-1, scaling=None, **kwds):
        """
        Plan an inverse DST along `axis`, using the transform type returned
        by HostFFTI.idst. The output is scaled by `scaling` (defaults to 1/s,
        with s also returned by HostFFTI.idst).
        Planning destroys initial arrays content.
        """
        (shape, dtype, itype, s) = super().idst(
            a=a, out=out, type=type, axis=axis, **kwds
        )
        scaling = first_not_None(scaling, 1.0 / s)
        direction = self._r2r_kind(self._DST_KINDS, itype)
        return self._build_plan(
            a, out, shape, dtype, axis, direction, kwds, scaling=scaling
        )